{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating, Training and Evaluating Uplift Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we'll demonstrate how to create, train, and evaluate uplift models by applying the uplift modelling technique.\n",
    "\n",
    "- What is uplift modelling?\n",
    "\n",
    "    It is a family of causal inference techniques that use machine learning models to estimate the causal effect of a treatment on an individual's behaviour. Individuals are typically divided into four groups:\n",
    "\n",
    "    - **Persuadables** respond positively only when given the treatment\n",
    "    - **Sleeping-dogs** have a strong negative response to the treatment\n",
    "    - **Lost Causes** will never reach the outcome even with the treatment\n",
    "    - **Sure Things** will always reach the outcome with or without the treatment\n",
    "\n",
    "    The goal of uplift modelling is to identify the \"persuadables\", avoid wasting effort on \"sure things\" and \"lost causes\", and avoid bothering \"sleeping dogs\".\n",
    "\n",
    "- How does uplift modelling work?\n",
    "    - **Meta Learner**: predicts the difference between an individual's behaviour when there is a treatment and when there is no treatment\n",
    "\n",
    "    - **Uplift Tree**: a tree-based algorithm where the splitting criterion is based on differences in uplift\n",
    "\n",
    "    - **NN-based Model**: a neural network model that usually works with observational data\n",
    "\n",
    "- Where can uplift modelling be applied?\n",
    "    - Marketing: help to identify persuadables to apply a treatment such as a coupon or an online advertisement\n",
    "    - Medical Treatment: help to understand how a treatment can impact certain groups differently\n",
    "    \n"
   ]
  },
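  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the meta-learner idea concrete, here is a minimal T-Learner sketch on synthetic data (a hypothetical illustration using scikit-learn, separate from the Spark pipeline in this notebook): fit one classifier on the treated group and one on the control group, then take the difference of their predicted probabilities as the estimated uplift.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(1000, 3))     # features\n",
    "t = rng.integers(0, 2, size=1000)  # treatment indicator\n",
    "# synthetic outcome: treatment helps only when the first feature is positive\n",
    "y = (rng.random(1000) < 0.1 + 0.2 * t * (X[:, 0] > 0)).astype(int)\n",
    "\n",
    "treated_model = LogisticRegression().fit(X[t == 1], y[t == 1])\n",
    "control_model = LogisticRegression().fit(X[t == 0], y[t == 0])\n",
    "\n",
    "# estimated uplift = P(outcome | treated) - P(outcome | control)\n",
    "uplift = treated_model.predict_proba(X)[:, 1] - control_model.predict_proba(X)[:, 1]\n",
    "```\n",
    "\n",
    "Individuals with the largest estimated uplift are the candidates worth targeting."
   ]
  },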
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 1: Load the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Notebook Configurations\n",
    "\n",
    "By defining the parameters below, we can easily apply this notebook to different datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "IS_CUSTOM_DATA = False  # if True, dataset has to be uploaded manually by user\n",
    "DATA_FOLDER = \"Files/uplift-modelling\"\n",
    "DATA_FILE = \"criteo-research-uplift-v2.1.csv\"\n",
    "\n",
    "# data schema\n",
    "FEATURE_COLUMNS = [f\"f{i}\" for i in range(12)]\n",
    "TREATMENT_COLUMN = \"treatment\"\n",
    "LABEL_COLUMN = \"visit\"\n",
    "\n",
    "EXPERIMENT_NAME = \"aisample-upliftmodelling\"  # mlflow experiment name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyspark.sql.functions as F\n",
    "from pyspark.sql.window import Window\n",
    "from pyspark.sql.types import *\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.style as style\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "from synapse.ml.featurize import Featurize\n",
    "from synapse.ml.core.spark import FluentAPI\n",
    "from synapse.ml.lightgbm import *\n",
    "from synapse.ml.train import ComputeModelStatistics\n",
    "\n",
    "import os\n",
    "import gzip\n",
    "\n",
    "import mlflow\n",
    "import trident.mlflow\n",
    "from trident.mlflow import get_sds_url"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Download dataset and upload to lakehouse\n",
    "\n",
    "**Please add a lakehouse to the notebook before running it.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "- Dataset description: This dataset was created by the Criteo AI Lab. It consists of 13M rows, each representing a user with 12 features, a treatment indicator, and 2 binary labels (visit and conversion).\n",
    "    - f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11: feature values (dense, float)\n",
    "    - treatment: treatment group (1 = treated, 0 = control) which indicates if a customer was targeted by advertising randomly\n",
    "    - conversion: whether a conversion occurred for this user (binary, label)\n",
    "    - visit: whether a visit occurred for this user (binary, label)\n",
    "\n",
    "- Dataset homepage: https://ailab.criteo.com/criteo-uplift-prediction-dataset/\n",
    "\n",
    "- Citation:\n",
    "    ```\n",
    "    @inproceedings{Diemert2018,\n",
    "    author = {{Diemert Eustache, Betlei Artem} and Renaudin, Christophe and Massih-Reza, Amini},\n",
    "    title={A Large Scale Benchmark for Uplift Modeling},\n",
    "    publisher = {ACM},\n",
    "    booktitle = {Proceedings of the AdKDD and TargetAd Workshop, KDD, London, United Kingdom, August, 20, 2018},\n",
    "    year = {2018}\n",
    "    }\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "if not IS_CUSTOM_DATA:\n",
    "    # Download demo data files into lakehouse if not exist\n",
    "    import os, requests\n",
    "\n",
    "    remote_url = \"http://go.criteo.net/criteo-research-uplift-v2.1.csv.gz\"\n",
    "    download_file = \"criteo-research-uplift-v2.1.csv.gz\"\n",
    "    download_path = f\"/lakehouse/default/{DATA_FOLDER}/raw\"\n",
    "\n",
    "    if not os.path.exists(\"/lakehouse/default\"):\n",
    "        raise FileNotFoundError(\"Default lakehouse not found, please add a lakehouse.\")\n",
    "    os.makedirs(download_path, exist_ok=True)\n",
    "    if not os.path.exists(f\"{download_path}/{DATA_FILE}\"):\n",
    "        r = requests.get(remote_url, timeout=30)\n",
    "        with open(f\"{download_path}/{download_file}\", \"wb\") as f:\n",
    "            f.write(r.content)\n",
    "        with gzip.open(f\"{download_path}/{download_file}\", \"rb\") as fin:\n",
    "            with open(f\"{download_path}/{DATA_FILE}\", \"wb\") as fout:\n",
    "                fout.write(fin.read())\n",
    "        print(\"Downloaded demo data files into lakehouse.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Read data from lakehouse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "raw_df = spark.read.csv(\n",
    "    f\"{DATA_FOLDER}/raw/{DATA_FILE}\", header=True, inferSchema=True\n",
    ").cache()\n",
    "\n",
    "display(raw_df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 2: Prepare the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Data exploration\n",
    "\n",
    "- **The overall rate of users that visit/convert**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"cbdb21d2-1420-495e-a169-c60f8314af7f\",\"activityId\":\"3806e4c7-a95e-44fb-b0e8-de2a0203b554\",\"applicationId\":\"application_1661918651252_0001\",\"jobGroupId\":\"9\",\"advices\":{\"warn\":1}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "raw_df.select(\n",
    "    F.mean(\"visit\").alias(\"Percentage of users that visit\"),\n",
    "    F.mean(\"conversion\").alias(\"Percentage of users that convert\"),\n",
    "    (F.sum(\"conversion\") / F.sum(\"visit\")).alias(\"Percentage of visitors that convert\"),\n",
    ").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "- **The overall average treatment effect on visit**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"cbdb21d2-1420-495e-a169-c60f8314af7f\",\"activityId\":\"3806e4c7-a95e-44fb-b0e8-de2a0203b554\",\"applicationId\":\"application_1661918651252_0001\",\"jobGroupId\":\"10\",\"advices\":{\"warn\":1}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "raw_df.groupby(\"treatment\").agg(\n",
    "    F.mean(\"visit\").alias(\"Mean of visit\"),\n",
    "    F.sum(\"visit\").alias(\"Sum of visit\"),\n",
    "    F.count(\"visit\").alias(\"Count\"),\n",
    ").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "- **The overall average treatment effect on conversion**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"cbdb21d2-1420-495e-a169-c60f8314af7f\",\"activityId\":\"3806e4c7-a95e-44fb-b0e8-de2a0203b554\",\"applicationId\":\"application_1661918651252_0001\",\"jobGroupId\":\"11\",\"advices\":{\"warn\":1}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "raw_df.groupby(\"treatment\").agg(\n",
    "    F.mean(\"conversion\").alias(\"Mean of conversion\"),\n",
    "    F.sum(\"conversion\").alias(\"Sum of conversion\"),\n",
    "    F.count(\"conversion\").alias(\"Count\"),\n",
    ").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Split train-test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "transformer = (\n",
    "    Featurize().setOutputCol(\"features\").setInputCols(FEATURE_COLUMNS).fit(raw_df)\n",
    ")\n",
    "\n",
    "df = transformer.transform(raw_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)\n",
    "\n",
    "print(\"Size of train dataset: %d\" % train_df.count())\n",
    "print(\"Size of test dataset: %d\" % test_df.count())\n",
    "\n",
    "train_df.groupby(TREATMENT_COLUMN).count().show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Split treatment-control dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "treatment_train_df = train_df.where(f\"{TREATMENT_COLUMN} > 0\")\n",
    "control_train_df = train_df.where(f\"{TREATMENT_COLUMN} = 0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 3: Model Training and Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Uplift Modelling: T-Learner with LightGBM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "classifier = (\n",
    "    LightGBMClassifier()\n",
    "    .setFeaturesCol(\"features\")\n",
    "    .setNumLeaves(10)\n",
    "    .setNumIterations(100)\n",
    "    .setObjective(\"binary\")\n",
    "    .setLabelCol(LABEL_COLUMN)\n",
    ")\n",
    "\n",
    "treatment_model = classifier.fit(treatment_train_df)\n",
    "control_model = classifier.fit(control_train_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Predict on test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "getPred = F.udf(lambda v: float(v[1]), FloatType())\n",
    "\n",
    "test_pred_df = (\n",
    "    test_df.mlTransform(treatment_model)\n",
    "    .withColumn(\"treatment_pred\", getPred(\"probability\"))\n",
    "    .drop(\"rawPrediction\", \"probability\", \"prediction\")\n",
    "    .mlTransform(control_model)\n",
    "    .withColumn(\"control_pred\", getPred(\"probability\"))\n",
    "    .drop(\"rawPrediction\", \"probability\", \"prediction\")\n",
    "    .withColumn(\"pred_uplift\", F.col(\"treatment_pred\") - F.col(\"control_pred\"))\n",
    "    .select(\n",
    "        TREATMENT_COLUMN, LABEL_COLUMN, \"treatment_pred\", \"control_pred\", \"pred_uplift\"\n",
    "    )\n",
    "    .cache()\n",
    ")\n",
    "\n",
    "display(test_pred_df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Model evaluation\n",
    "\n",
    "Since the actual uplift cannot be observed for an individual, we measure uplift over groups of customers.\n",
    "\n",
    "- **Uplift Curve**: plots the real cumulative uplift across the population\n",
    "\n",
    "First, we rank the test dataframe by the predicted uplift."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "test_ranked_df = test_pred_df.withColumn(\n",
    "    \"percent_rank\", F.percent_rank().over(Window.orderBy(F.desc(\"pred_uplift\")))\n",
    ")\n",
    "\n",
    "display(test_ranked_df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we calculate the cumulative percentage of visits in each group (treatment or control)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "C = test_ranked_df.where(f\"{TREATMENT_COLUMN} == 0\").count()\n",
    "T = test_ranked_df.where(f\"{TREATMENT_COLUMN} != 0\").count()\n",
    "\n",
    "test_ranked_df = (\n",
    "    test_ranked_df.withColumn(\n",
    "        \"control_label\",\n",
    "        F.when(F.col(TREATMENT_COLUMN) == 0, F.col(LABEL_COLUMN)).otherwise(0),\n",
    "    )\n",
    "    .withColumn(\n",
    "        \"treatment_label\",\n",
    "        F.when(F.col(TREATMENT_COLUMN) != 0, F.col(LABEL_COLUMN)).otherwise(0),\n",
    "    )\n",
    "    .withColumn(\n",
    "        \"control_cumsum\",\n",
    "        F.sum(\"control_label\").over(Window.orderBy(\"percent_rank\")) / C,\n",
    "    )\n",
    "    .withColumn(\n",
    "        \"treatment_cumsum\",\n",
    "        F.sum(\"treatment_label\").over(Window.orderBy(\"percent_rank\")) / T,\n",
    "    )\n",
    ")\n",
    "\n",
    "display(test_ranked_df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we calculate the group uplift at each percent rank."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ranked_df = test_ranked_df.withColumn(\n",
    "    \"group_uplift\", F.col(\"treatment_cumsum\") - F.col(\"control_cumsum\")\n",
    ").cache()\n",
    "\n",
    "display(test_ranked_df.limit(20))"
   ]
  },
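  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps above (rank by predicted uplift, per-group cumulative visit shares, their difference) can be cross-checked with a small pandas sketch on toy data (a hypothetical illustration, independent of the Spark pipeline):\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# toy predictions and outcomes (hypothetical data)\n",
    "df = pd.DataFrame({\n",
    "    \"pred_uplift\": [0.9, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0],\n",
    "    \"treatment\": [1, 0, 1, 0, 1, 0, 1, 0],\n",
    "    \"visit\": [1, 0, 1, 1, 0, 0, 1, 0],\n",
    "})\n",
    "\n",
    "# 1) rank by predicted uplift (descending)\n",
    "df = df.sort_values(\"pred_uplift\", ascending=False).reset_index(drop=True)\n",
    "\n",
    "T = int((df[\"treatment\"] == 1).sum())  # treatment group size\n",
    "C = int((df[\"treatment\"] == 0).sum())  # control group size\n",
    "\n",
    "# 2) cumulative visit share within each group\n",
    "df[\"treatment_cumsum\"] = (df[\"visit\"] * (df[\"treatment\"] == 1)).cumsum() / T\n",
    "df[\"control_cumsum\"] = (df[\"visit\"] * (df[\"treatment\"] == 0)).cumsum() / C\n",
    "\n",
    "# 3) group uplift = difference of the cumulative shares\n",
    "df[\"group_uplift\"] = df[\"treatment_cumsum\"] - df[\"control_cumsum\"]\n",
    "```\n"
   ]
  },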
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "Now we can plot the uplift curve for the predictions on the test dataset. We need to convert the PySpark dataframe to a pandas dataframe before plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "def uplift_plot(uplift_df):\n",
    "    \"\"\"\n",
    "    Plot the uplift curve\n",
    "    \"\"\"\n",
    "    gain_x = uplift_df.percent_rank\n",
    "    gain_y = uplift_df.group_uplift\n",
    "    # plot the data\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    mpl.rcParams[\"font.size\"] = 8\n",
    "\n",
    "    ax = plt.plot(gain_x, gain_y, color=\"#2077B4\", label=\"Normalized Uplift Model\")\n",
    "\n",
    "    plt.plot(\n",
    "        [0, gain_x.max()],\n",
    "        [0, gain_y.max()],\n",
    "        \"--\",\n",
    "        color=\"tab:orange\",\n",
    "        label=\"Random Treatment\",\n",
    "    )\n",
    "    plt.legend()\n",
    "    plt.xlabel(\"Proportion Targeted\")\n",
    "    plt.ylabel(\"Uplift\")\n",
    "    plt.grid(True, which=\"major\")\n",
    "\n",
    "    return ax\n",
    "\n",
    "\n",
    "test_ranked_pd_df = test_ranked_df.select(\n",
    "    [\"pred_uplift\", \"percent_rank\", \"group_uplift\"]\n",
    ").toPandas()\n",
    "uplift_plot(test_ranked_pd_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![criteo_uplift_curve.jpeg](https://mmlspark.blob.core.windows.net/graphics/notebooks/criteo_uplift_curve.jpeg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "From the uplift curve above, we notice that the top 20% of the population, ranked by our prediction, would see a large gain if given the treatment, which means they are the **persuadables**. Therefore, we can print the cutoff score at the 20% mark to identify the target customers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "cutoff_percentage = 0.2\n",
    "cutoff_score = test_ranked_pd_df.iloc[int(len(test_ranked_pd_df) * cutoff_percentage)][\n",
    "    \"pred_uplift\"\n",
    "]\n",
    "\n",
    "print(\"Uplift scores higher than {:.4f} indicate Persuadables\".format(cutoff_score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log and Load Model with MLflow\n",
    "Now that we have trained models, we can save them for later use. Here we use MLflow to log metrics and models. We can also use this API to load models back for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup mlflow\n",
    "mlflow.set_tracking_uri(get_sds_url())\n",
    "mlflow.set_registry_uri(get_sds_url())\n",
    "mlflow.set_experiment(EXPERIMENT_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# log model, metrics and params\n",
    "with mlflow.start_run() as run:\n",
    "    print(\"log model:\")\n",
    "    mlflow.spark.log_model(\n",
    "        treatment_model,\n",
    "        f\"{EXPERIMENT_NAME}-treatmentmodel\",\n",
    "        registered_model_name=f\"{EXPERIMENT_NAME}-treatmentmodel\",\n",
    "        dfs_tmpdir=\"Files/spark\",\n",
    "    )\n",
    "\n",
    "    mlflow.spark.log_model(\n",
    "        control_model,\n",
    "        f\"{EXPERIMENT_NAME}-controlmodel\",\n",
    "        registered_model_name=f\"{EXPERIMENT_NAME}-controlmodel\",\n",
    "        dfs_tmpdir=\"Files/spark\",\n",
    "    )\n",
    "\n",
    "    model_uri = f\"runs:/{run.info.run_id}/{EXPERIMENT_NAME}\"\n",
    "    print(\"Model saved in run %s\" % run.info.run_id)\n",
    "    print(f\"Model URI: {model_uri}-treatmentmodel\")\n",
    "    print(f\"Model URI: {model_uri}-controlmodel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model back\n",
    "loaded_treatmentmodel = mlflow.spark.load_model(\n",
    "    f\"{model_uri}-treatmentmodel\", dfs_tmpdir=\"Files/spark\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernel_info": {
   "name": "synapse_pyspark"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "notebook_environment": {},
  "save_output": true,
  "spark_compute": {
   "compute_id": "/trident/default",
   "session_options": {
    "conf": {
     "spark.livy.synapse.ipythonInterpreter.enabled": "true"
    },
    "enableDebugMode": false,
    "keepAliveTimeout": 30
   }
  },
  "trident": {
   "lakehouse": {
    "default_lakehouse": "4eb0cf6c-9e08-4640-bc5c-206cba1864b2",
    "known_lakehouses": [
     {
      "id": "4eb0cf6c-9e08-4640-bc5c-206cba1864b2"
     }
    ]
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "8cebba326b76ca708172f0a6a24a89689a3b64f83dbd9353b827f2f4b33d3f80"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
